{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## pack_padded_sequence\n",
    "- https://pytorch.org/docs/1.0.1/nn.html#torch.nn.utils.rnn.pack_padded_sequence\n",
    "- 在使用深度学习特别是LSTM进行文本分析时，经常会遇到文本长度不一样的情况，此时就需要对同一个batch中的不同文本使用padding的方式进行文本长度对齐，方便将训练数据输入到LSTM模型进行训练，同时为了保证模型训练的精度，应该同时告诉LSTM相关padding的情况，此时，pytorch中的pack_padded_sequence就有了用武之地。\n",
    "- 通常padding的位置向量都是0，我们需要使用pack_padded_sequence() 把数据压紧，即去掉padding的部分，减少冗余。然后再输入网络中，如lstm等。通过网络后的结果也是压紧的，需要通过pad_packed_sequence()还原。\n",
    "- torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False)\n",
    "```\n",
    "input (Tensor) – padded batch of variable length sequences.\n",
    "lengths (Tensor) – list of sequences lengths of each batch element.\n",
    "batch_first (bool, optional) – if True, the input is expected in B x T x * format.\n",
    "```\n",
    "- pad_packed_sequence 解压"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.utils as utils\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0.1\n"
     ]
    }
   ],
   "source": [
    "# Show the installed PyTorch version (the saved output of this cell was produced with 1.0.1).\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a bidirectional LSTM layer: 4 input features per timestep, 100 hidden units per direction.\n",
    "# batch_first=True means inputs and outputs are laid out as (batch, seq, feature).\n",
    "lstm = nn.LSTM(4, 100, num_layers=1, batch_first=True, bidirectional=True)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 4])"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 定义一个有padding的序列数据，也就是有冗余的0\n",
    "x = torch.tensor([[[1,2,3,4],\n",
    "                   [2,3,4,5],\n",
    "                   [2,5,6,0]],\n",
    "                  [[1,2,1,1],\n",
    "                   [1,6,7,9],\n",
    "                   [0,0,0,0]],\n",
    "                  [[1,2,3,4],\n",
    "                   [1,1,1,1],\n",
    "                   [0,0,0,0]],\n",
    "                  [[1,2,3,4],\n",
    "                   [0,0,0,0],\n",
    "                   [0,0,0,0]],\n",
    "                 ])\n",
    "x = x.float()\n",
    "x.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PackedSequence(data=tensor([[1., 2., 3., 4.],\n",
       "        [1., 2., 1., 1.],\n",
       "        [1., 2., 3., 4.],\n",
       "        [1., 2., 3., 4.],\n",
       "        [2., 3., 4., 5.],\n",
       "        [1., 6., 7., 9.],\n",
       "        [1., 1., 1., 1.],\n",
       "        [2., 5., 6., 0.]]), batch_sizes=tensor([4, 3, 1]))"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pack the padded batch, dropping the redundant padding rows (8 real timesteps remain out of 12).\n",
    "# The lengths tensor gives the true length of each sequence in x, listed in decreasing order.\n",
    "# NOTE: PyTorch 1.0.1 expects the sequences sorted by decreasing length -- confirm before reordering x.\n",
    "packed = pack_padded_sequence(x, torch.tensor([3, 2, 2,1]), batch_first=True)   # pack / compress\n",
    "packed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PackedSequence(data=tensor([[-0.1277, -0.0269, -0.0070,  ..., -0.0997,  0.0827,  0.0156],\n",
       "        [-0.0272, -0.0503, -0.0239,  ..., -0.0778,  0.0617, -0.0288],\n",
       "        [-0.1277, -0.0269, -0.0070,  ..., -0.0498,  0.0694,  0.0405],\n",
       "        ...,\n",
       "        [-0.2541, -0.0296, -0.0304,  ..., -0.0526,  0.0918,  0.0056],\n",
       "        [-0.0754, -0.0565, -0.0133,  ..., -0.0139,  0.0202,  0.0313],\n",
       "        [-0.2212, -0.0844, -0.0038,  ..., -0.0590, -0.0370,  0.0686]],\n",
       "       grad_fn=<CatBackward>), batch_sizes=tensor([4, 3, 1]))"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the packed batch through the bidirectional LSTM.\n",
    "output, hidden = lstm(packed)\n",
    "# The result is still packed: a PackedSequence with the same batch_sizes as the input.\n",
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack: restore the padded (batch, seq, feature) layout and recover the per-sequence lengths.\n",
    "# NOTE: the variable name `lenghts` (sic) is kept as-is because a later cell displays it.\n",
    "encoder_outputs, lenghts = pad_packed_sequence(output, batch_first=True)   # unpack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.1277, -0.0269, -0.0070,  ..., -0.0997,  0.0827,  0.0156],\n",
       "         [-0.1971, -0.0458, -0.0240,  ..., -0.0674,  0.0357,  0.0496],\n",
       "         [-0.2212, -0.0844, -0.0038,  ..., -0.0590, -0.0370,  0.0686]],\n",
       "\n",
       "        [[-0.0272, -0.0503, -0.0239,  ..., -0.0778,  0.0617, -0.0288],\n",
       "         [-0.2541, -0.0296, -0.0304,  ..., -0.0526,  0.0918,  0.0056],\n",
       "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
       "\n",
       "        [[-0.1277, -0.0269, -0.0070,  ..., -0.0498,  0.0694,  0.0405],\n",
       "         [-0.0754, -0.0565, -0.0133,  ..., -0.0139,  0.0202,  0.0313],\n",
       "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
       "\n",
       "        [[-0.1277, -0.0269, -0.0070,  ..., -0.0408,  0.0545,  0.0386],\n",
       "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],\n",
       "       grad_fn=<TransposeBackward0>)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Padded output: positions past each sequence's true length are filled with zeros.\n",
    "encoder_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 200])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoder_outputs.size()   # size: [4, 3, 200] -- batch 4, max length 3; bidirectional, so 2 x 100 = 200 features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 200])"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Make sure the tensor is stored contiguously before it is flattened with .view() in the next step.\n",
    "encoder_outputs = encoder_outputs.contiguous()\n",
    "encoder_outputs.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([12, 200])"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Flatten the (batch, seq) axes into one so that every timestep becomes a row.\n",
    "# The feature width is read from the tensor instead of hard-coding 200, so this cell\n",
    "# keeps working if hidden_size or bidirectional is changed on the LSTM.\n",
    "feature = encoder_outputs.view(-1, encoder_outputs.size(-1))\n",
    "feature.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([12, 200])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bias-free linear projection applied to every flattened timestep row.\n",
    "# in_features is taken from `feature` rather than hard-coded; the output width stays at 200.\n",
    "line = nn.Linear(feature.size(-1), 200, bias=False)\n",
    "y = line(feature)\n",
    "y.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([3, 2, 2, 1])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sequence lengths recovered by pad_packed_sequence; they match the lengths used when packing.\n",
    "lenghts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## Inspect the hidden state returned by the bidirectional LSTM\n",
    "len(hidden) # 2 because hidden is the tuple (h_n, c_n), not because the LSTM is bidirectional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 100])"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Split the tuple into the final hidden state h and the final cell state c.\n",
    "# h has shape (2, 4, 100): 2 directions x 1 layer, batch of 4, 100 hidden units.\n",
    "h, c  = hidden\n",
    "h.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 100])"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The cell state has the same shape as the hidden state.\n",
    "c.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 200])"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Move the batch axis first, then flatten the 2 directions x 100 units into one 200-wide vector per sequence.\n",
    "# .contiguous() is required here: view() (unlike reshape()) cannot merge axes of the transposed, non-contiguous tensor.\n",
    "inp = h.transpose(0, 1).contiguous().view(-1, 200)\n",
    "inp.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a unidirectional LSTM layer (4 input features, 100 hidden units, batch-first layout).\n",
    "lstm2 = nn.LSTM(4, 100, num_layers=1, batch_first=True) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: x is passed as the raw padded tensor (not packed), so the all-zero padding rows are\n",
    "# processed as ordinary timesteps; the following cells only inspect the hidden-state shapes.\n",
    "out, hid = lstm2(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Still 2 for the unidirectional LSTM: the returned state is the tuple (h_n, c_n).\n",
    "len(hid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 4, 100])"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# h_n of the unidirectional LSTM: a single direction, so the first axis is 1 instead of 2.\n",
    "hid[0].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "## It is worth getting familiar with the structure of the LSTM outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
